import os
import logging
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchvision

from train_argument import parser, print_args

import random
import copy

from time import time
from model import net
from utils import *
from Simulator import Simulator
from Split_Data import Non_iid_split, data_stats

# Per-dataset settings. 'dataset_cls' is the attribute name looked up on
# torchvision.datasets; 'mean'/'std' are the per-channel values handed to
# torchvision.transforms.Normalize; 'drop_last' is forwarded to the training
# DataLoader only.
_DATASET_SPECS = {
    'cifar10': {
        'dataset_cls': 'CIFAR10',
        'num_classes': 10,
        'mean': [0.49139968, 0.48215827, 0.44653124],
        'std': [0.24703233, 0.24348505, 0.26158768],
        # A single-sample batch makes the BN layers fail, so incomplete
        # trailing batches are dropped.
        'drop_last': True,
    },
    'cifar100': {
        'dataset_cls': 'CIFAR100',
        'num_classes': 100,
        'mean': [0.507, 0.487, 0.441],
        'std': [0.267, 0.256, 0.276],
        # NOTE(review): the CIFAR-100 training loader keeps its last partial
        # batch, so a client whose sample count is 1 modulo batch_size still
        # produces a single-sample batch (the case the CIFAR-10 path avoids).
        # Left unchanged so existing runs stay comparable -- confirm intent.
        'drop_last': False,
    },
}

# Values of args.model that main() knows how to build when args.mask == 1.
_SUPPORTED_MODELS = ('conv4', 'resnet')


def _load_datasets(spec, data_root):
    """Return the (train, test) torchvision datasets described by *spec*.

    The training set is augmented (pad 4, random 32x32 crop, random horizontal
    flip); the test set is only converted to a tensor and normalised. Both are
    downloaded into *data_root* when missing (``download=True``).
    """
    tfm = torchvision.transforms
    dataset_cls = getattr(torchvision.datasets, spec['dataset_cls'])

    train_transform = tfm.Compose([
        tfm.Pad(4),
        tfm.RandomCrop(32),
        tfm.RandomHorizontalFlip(),
        tfm.ToTensor(),
        tfm.Normalize(spec['mean'], spec['std']),
    ])
    test_transform = tfm.Compose([
        tfm.ToTensor(),
        tfm.Normalize(spec['mean'], spec['std']),
    ])

    tr_dataset = dataset_cls(data_root, train=True,
                             transform=train_transform, download=True)
    te_dataset = dataset_cls(data_root, train=False,
                             transform=test_transform, download=True)
    return tr_dataset, te_dataset


def _split_until_usable(num_classes, num_clients, tr_dataset, te_dataset,
                        alpha, min_train_samples):
    """Repeat Non_iid_split until every client partition is usable.

    A split is accepted once the smallest per-client training total is at
    least *min_train_samples* and the smallest per-client test total is
    non-zero. The totals are the second value returned by data_stats.

    Returns:
        (train partitions, test partitions, train totals, test totals)
    """
    # NOTE(review): there is no retry cap (the original loops had none
    # either); an unsatisfiable configuration keeps resplitting forever.
    while True:
        tr_parts, te_parts = Non_iid_split(
            num_classes, num_clients, tr_dataset, te_dataset, alpha)
        _, tr_totals = data_stats(tr_parts, num_classes, num_clients)
        _, te_totals = data_stats(te_parts, num_classes, num_clients)
        if np.min(tr_totals) >= min_train_samples and np.min(te_totals) > 0:
            return tr_parts, te_parts, tr_totals, te_totals


def _build_model(args, num_classes, device):
    """Instantiate the masked network named by ``args.model`` on *device*."""
    if args.model == 'conv4':
        return net.masked_Conv4_BN().to(device)
    if args.model == 'resnet':
        return net.MaskedWideResNet(depth=args.depth,
                                    num_classes=num_classes,
                                    widen_factor=args.widen_factor).to(device)
    raise ValueError("Unsupported model %r; expected one of %s"
                     % (args.model, ', '.join(_SUPPORTED_MODELS)))


def main(args):
    """Prepare per-client data loaders and run the federated simulation.

    Reads ``affix``, ``log_root``, ``model_root``, ``dataset``, ``data_root``,
    ``num_clients``, ``alpha``, ``batch_size``, ``mask``, ``model`` and, for
    the 'resnet' model, ``depth`` and ``widen_factor`` from *args*. Sets
    ``args.log_folder`` and ``args.model_folder`` as a side effect.

    Training (``Simulator.initialization`` followed by ``FL_loop``) only runs
    when ``args.mask == 1``.

    Raises:
        ValueError: if ``args.dataset`` is not a known dataset, or if
            ``args.mask == 1`` and ``args.model`` is not a known model.
            Both are checked before any folder is created or data is
            downloaded.
    """
    # Fail fast on bad configuration instead of hitting an undefined-variable
    # error after the folders, logger and datasets have already been set up.
    if args.dataset not in _DATASET_SPECS:
        raise ValueError("Unsupported dataset %r; expected one of %s"
                         % (args.dataset, ', '.join(sorted(_DATASET_SPECS))))
    if args.mask == 1 and args.model not in _SUPPORTED_MODELS:
        raise ValueError("Unsupported model %r; expected one of %s"
                         % (args.model, ', '.join(_SUPPORTED_MODELS)))
    spec = _DATASET_SPECS[args.dataset]
    num_classes = spec['num_classes']

    save_folder = args.affix
    log_folder = os.path.join(args.log_root, save_folder)
    model_folder = os.path.join(args.model_root, save_folder)
    makedirs(log_folder)
    makedirs(model_folder)

    # Expose the resolved folders on args (args is also handed to Simulator).
    args.log_folder = log_folder
    args.model_folder = model_folder

    logger = create_logger(log_folder, 'train', 'info')
    print_args(args, logger)

    tr_dataset, te_dataset = _load_datasets(spec, args.data_root)

    # With drop_last a client needs a full batch to yield anything at all;
    # without it a single sample is enough to build the shuffling loader.
    min_train_samples = args.batch_size if spec['drop_last'] else 1
    (Non_iid_tr_datasets, Non_iid_te_datasets,
     client_total_samples, client_total_te_samples) = _split_until_usable(
        num_classes, args.num_clients, tr_dataset, te_dataset,
        args.alpha, min_train_samples)

    local_tr_data_loaders = [
        DataLoader(dataset, num_workers=0, batch_size=args.batch_size,
                   shuffle=True, drop_last=spec['drop_last'])
        for dataset in Non_iid_tr_datasets
    ]
    local_te_data_loaders = [
        DataLoader(dataset, num_workers=0, batch_size=args.batch_size,
                   shuffle=True)
        for dataset in Non_iid_te_datasets
    ]

    print("tr_data counts: ", client_total_samples)
    print("te_data_counts: ", client_total_te_samples)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print("currrent device: ", device)

    if args.mask != 1:
        # Previously this path returned silently; say so in the log.
        logger.info('args.mask is %r (not 1): no model built, no training run',
                    args.mask)
        return

    model = _build_model(args, num_classes, device)
    logger.info(model)

    trainer = Simulator(args, logger, local_tr_data_loaders,
                        local_te_data_loaders, device)
    # The simulator receives its own copy of the freshly initialised model.
    trainer.initialization(copy.deepcopy(model))
    trainer.FL_loop()
        

if __name__ == '__main__':
    # Script entry point: parse the CLI arguments, pin the visible GPUs,
    # seed the random number generators, then hand over to main().
    args = parser()

    # Set before the CUDA seeding call and main() below.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # Same seed for Python, NumPy, torch (CPU) and torch (all CUDA devices),
    # applied in that order.
    for seed_rng in (random.seed,
                     np.random.seed,
                     torch.manual_seed,
                     torch.cuda.manual_seed_all):
        seed_rng(args.seed)
    torch.backends.cudnn.deterministic = True

    main(args)